{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prompt tuning for causal language modeling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Prompting helps guide language model behavior by adding some input text specific to a task. Prompt tuning is an additive method for only training and updating the newly added prompt tokens to a pretrained model. This way, you can use one pretrained model whose weights are frozen, and train and update a smaller set of prompt parameters for each downstream task instead of fully finetuning a separate model. As models grow larger and larger, prompt tuning can be more efficient, and results are even better as model parameters scale.\n",
    "\n",
    "<Tip>\n",
    "\n",
    "💡 Read [The Power of Scale for Parameter-Efficient Prompt Tuning](https://arxiv.org/abs/2104.08691) to learn more about prompt tuning.\n",
    "\n",
    "</Tip>\n",
    "\n",
    "This guide will show you how to apply prompt tuning to train a [`bloomz-560m`](https://huggingface.co/bigscience/bloomz-560m) model on the `twitter_complaints` subset of the [RAFT](https://huggingface.co/datasets/ought/raft) dataset.\n",
    "\n",
    "Before you begin, make sure you have all the necessary libraries installed:\n",
    "\n",
    "```bash\n",
    "!pip install -q peft transformers datasets\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Start by defining the model and tokenizer, the dataset and the dataset columns to train on, some training hyperparameters, and the [PromptTuningConfig](https://huggingface.co/docs/peft/main/en/package_reference/tuners#peft.PromptTuningConfig). The [PromptTuningConfig](https://huggingface.co/docs/peft/main/en/package_reference/tuners#peft.PromptTuningConfig) contains information about the task type, the text to initialize the prompt embedding, the number of virtual tokens, and the tokenizer to use:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForCausalLM, AutoTokenizer, default_data_collator, get_linear_schedule_with_warmup\n",
    "from peft import get_peft_config, get_peft_model, PromptTuningInit, PromptTuningConfig, TaskType, PeftType\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "import os\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "device = \"cuda\"\n",
    "model_name_or_path = \"bigscience/bloomz-560m\"\n",
    "tokenizer_name_or_path = \"bigscience/bloomz-560m\"\n",
    "peft_config = PromptTuningConfig(\n",
    "    task_type=TaskType.CAUSAL_LM,\n",
    "    prompt_tuning_init=PromptTuningInit.TEXT,\n",
    "    num_virtual_tokens=8,\n",
    "    prompt_tuning_init_text=\"Classify if the tweet is a complaint or not:\",\n",
    "    tokenizer_name_or_path=model_name_or_path,\n",
    ")\n",
    "\n",
    "dataset_name = \"twitter_complaints\"\n",
    "checkpoint_name = f\"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}_v1.pt\".replace(\n",
    "    \"/\", \"_\"\n",
    ")\n",
    "text_column = \"Tweet text\"\n",
    "label_column = \"text_label\"\n",
    "max_length = 64\n",
    "lr = 3e-2\n",
    "num_epochs = 50\n",
    "batch_size = 8"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For this guide, you'll load the `twitter_complaints` subset of the [RAFT](https://huggingface.co/datasets/ought/raft) dataset. This subset contains tweets that are labeled either `complaint` or `no complaint`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset(\"ought/raft\", dataset_name)\n",
    "dataset[\"train\"][0]\n",
    "{\"Tweet text\": \"@HMRCcustomers No this is my first job\", \"ID\": 0, \"Label\": 2}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the `Label` column more readable, replace the `Label` value with the corresponding label text and store them in a `text_label` column. You can use the [map](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.map) function to apply this change over the entire dataset in one step:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "classes = [k.replace(\"_\", \" \") for k in dataset[\"train\"].features[\"Label\"].names]\n",
    "dataset = dataset.map(\n",
    "    lambda x: {\"text_label\": [classes[label] for label in x[\"Label\"]]},\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    ")\n",
    "dataset[\"train\"][0]\n",
    "{\"Tweet text\": \"@HMRCcustomers No this is my first job\", \"ID\": 0, \"Label\": 2, \"text_label\": \"no complaint\"}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocess dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, you'll setup a tokenizer; configure the appropriate padding token to use for padding sequences, and determine the maximum length of the tokenized labels:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
    "if tokenizer.pad_token_id is None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "target_max_length = max([len(tokenizer(class_label)[\"input_ids\"]) for class_label in classes])\n",
    "print(target_max_length)\n",
    "3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a `preprocess_function` to:\n",
    "\n",
    "1. Tokenize the input text and labels.\n",
    "2. For each example in a batch, pad the labels with the tokenizers `pad_token_id`.\n",
    "3. Concatenate the input text and labels into the `model_inputs`.\n",
    "4. Create a separate attention mask for `labels` and `model_inputs`.\n",
    "5. Loop through each example in the batch again to pad the input ids, labels, and attention mask to the `max_length` and convert them to PyTorch tensors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_function(examples):\n",
    "    batch_size = len(examples[text_column])\n",
    "    inputs = [f\"{text_column} : {x} Label : \" for x in examples[text_column]]\n",
    "    targets = [str(x) for x in examples[label_column]]\n",
    "    model_inputs = tokenizer(inputs)\n",
    "    labels = tokenizer(targets)\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        label_input_ids = labels[\"input_ids\"][i] + [tokenizer.pad_token_id]\n",
    "        # print(i, sample_input_ids, label_input_ids)\n",
    "        model_inputs[\"input_ids\"][i] = sample_input_ids + label_input_ids\n",
    "        labels[\"input_ids\"][i] = [-100] * len(sample_input_ids) + label_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [1] * len(model_inputs[\"input_ids\"][i])\n",
    "    # print(model_inputs)\n",
    "    for i in range(batch_size):\n",
    "        sample_input_ids = model_inputs[\"input_ids\"][i]\n",
    "        label_input_ids = labels[\"input_ids\"][i]\n",
    "        model_inputs[\"input_ids\"][i] = [tokenizer.pad_token_id] * (\n",
    "            max_length - len(sample_input_ids)\n",
    "        ) + sample_input_ids\n",
    "        model_inputs[\"attention_mask\"][i] = [0] * (max_length - len(sample_input_ids)) + model_inputs[\n",
    "            \"attention_mask\"\n",
    "        ][i]\n",
    "        labels[\"input_ids\"][i] = [-100] * (max_length - len(sample_input_ids)) + label_input_ids\n",
    "        model_inputs[\"input_ids\"][i] = torch.tensor(model_inputs[\"input_ids\"][i][:max_length])\n",
    "        model_inputs[\"attention_mask\"][i] = torch.tensor(model_inputs[\"attention_mask\"][i][:max_length])\n",
    "        labels[\"input_ids\"][i] = torch.tensor(labels[\"input_ids\"][i][:max_length])\n",
    "    model_inputs[\"labels\"] = labels[\"input_ids\"]\n",
    "    return model_inputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use the [map](https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.map) function to apply the `preprocess_function` to the entire dataset. You can remove the unprocessed columns since the model won't need them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "processed_datasets = dataset.map(\n",
    "    preprocess_function,\n",
    "    batched=True,\n",
    "    num_proc=1,\n",
    "    remove_columns=dataset[\"train\"].column_names,\n",
    "    load_from_cache_file=False,\n",
    "    desc=\"Running tokenizer on dataset\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) from the `train` and `eval` datasets. Set `pin_memory=True` to speed up the data transfer to the GPU during training if the samples in your dataset are on a CPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = processed_datasets[\"train\"]\n",
    "eval_dataset = processed_datasets[\"test\"]\n",
    "\n",
    "\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True\n",
    ")\n",
    "eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You're almost ready to setup your model and start training!\n",
    "\n",
    "Initialize a base model from [AutoModelForCausalLM](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForCausalLM), and pass it and `peft_config` to the `get_peft_model()` function to create a [PeftModel](https://huggingface.co/docs/peft/main/en/package_reference/peft_model#peft.PeftModel). You can print the new [PeftModel](https://huggingface.co/docs/peft/main/en/package_reference/peft_model#peft.PeftModel)'s trainable parameters to see how much more efficient it is than training the full parameters of the original model!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForCausalLM.from_pretrained(model_name_or_path)\n",
    "model = get_peft_model(model, peft_config)\n",
    "print(model.print_trainable_parameters())\n",
    "\"trainable params: 8192 || all params: 559222784 || trainable%: 0.0014648902430985358\""
   ]
  },
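  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sanity check on the count printed above: prompt tuning trains one embedding vector per virtual token, so the number of trainable parameters should equal `num_virtual_tokens` times the model's hidden size. Assuming a hidden size of 1024 for `bloomz-560m` (an assumption; this guide doesn't state it), the arithmetic works out as:\n",
    "\n",
    "$$8 \\times 1024 = 8192 \\qquad\\text{and}\\qquad \\frac{8192}{559222784} \\times 100 \\approx 0.00146\\%$$"
   ]
  },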
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Setup an optimizer and learning rate scheduler:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
    "lr_scheduler = get_linear_schedule_with_warmup(\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=0,\n",
    "    num_training_steps=(len(train_dataloader) * num_epochs),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Move the model to the GPU, then write a training loop to start training!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.to(device)\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    for step, batch in enumerate(tqdm(train_dataloader)):\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        total_loss += loss.detach().float()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "    model.eval()\n",
    "    eval_loss = 0\n",
    "    eval_preds = []\n",
    "    for step, batch in enumerate(tqdm(eval_dataloader)):\n",
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        eval_loss += loss.detach().float()\n",
    "        eval_preds.extend(\n",
    "            tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True)\n",
    "        )\n",
    "\n",
    "    eval_epoch_loss = eval_loss / len(eval_dataloader)\n",
    "    eval_ppl = torch.exp(eval_epoch_loss)\n",
    "    train_epoch_loss = total_loss / len(train_dataloader)\n",
    "    train_ppl = torch.exp(train_epoch_loss)\n",
    "    print(f\"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Share model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can store and share your model on the Hub if you'd like. Log in to your Hugging Face account and enter your token when prompted:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use the [push_to_hub](https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.PreTrainedModel.push_to_hub) function to upload your model to a model repository on the Hub:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "peft_model_id = \"your-name/bloomz-560m_PROMPT_TUNING_CAUSAL_LM\"\n",
    "model.push_to_hub(\"your-name/bloomz-560m_PROMPT_TUNING_CAUSAL_LM\", use_auth_token=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the model is uploaded, you'll see the model file size is only 33.5kB! 🤏"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's try the model on a sample input for inference. If you look at the repository you uploaded the model to, you'll see a `adapter_config.json` file. Load this file into [PeftConfig](https://huggingface.co/docs/peft/main/en/package_reference/config#peft.PeftConfig) to specify the `peft_type` and `task_type`. Then you can load the prompt tuned model weights, and the configuration into [from_pretrained()](https://huggingface.co/docs/peft/main/en/package_reference/peft_model#peft.PeftModel.from_pretrained) to create the [PeftModel](https://huggingface.co/docs/peft/main/en/package_reference/peft_model#peft.PeftModel):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import PeftModel, PeftConfig\n",
    "\n",
    "peft_model_id = \"stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM\"\n",
    "\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)\n",
    "model = PeftModel.from_pretrained(model, peft_model_id)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Grab a tweet and tokenize it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = tokenizer(\n",
    "    f'{text_column} : {\"@nationalgridus I have no water and the bill is current and paid. Can you do something about this?\"} Label : ',\n",
    "    return_tensors=\"pt\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Put the model on a GPU and *generate* the predicted label:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.to(device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    outputs = model.generate(\n",
    "        input_ids=inputs[\"input_ids\"], attention_mask=inputs[\"attention_mask\"], max_new_tokens=10, eos_token_id=3\n",
    "    )\n",
    "    print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))\n",
    "[\n",
    "    \"Tweet text : @nationalgridus I have no water and the bill is current and paid. Can you do something about this? Label : complaint\"\n",
    "]"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
